import os

from PIL import Image


def get_file_paths(folder):
    """Return absolute paths of the ``.png`` / ``.jpg`` files directly in *folder*.

    Only the top level of *folder* is scanned; subdirectories are ignored.
    Results are ordered by filename. The extension match is case-sensitive,
    so names such as ``x.PNG`` or ``x.jpeg`` are not included. If *folder*
    does not exist (``os.walk`` yields nothing), an empty list is returned.
    """
    first_level = next(os.walk(folder), None)
    if first_level is None:
        return []

    root, _subdirs, names = first_level
    base_dir = os.path.abspath(root)
    return [
        os.path.join(base_dir, name)
        for name in sorted(names)
        if name.endswith(('.png', '.jpg'))
    ]


def unalign_images(file_paths, target_path, max_num=1000):
    """Split side-by-side "AB" images into separate A and B JPEG files.

    Each image is converted to RGB and cut vertically down the middle. The
    left half is written to ``target_path + 'A'`` and the right half to
    ``target_path + 'B'``, both as ``<stem>.jpg`` (quality 100, no chroma
    subsampling). Both output directories are created if missing.

    Args:
        file_paths: Sequence of image paths, processed in order.
        target_path: Output prefix. ``'A'`` / ``'B'`` are appended directly,
            e.g. ``'out/test'`` -> ``'out/testA'`` and ``'out/testB'``.
        max_num: Maximum number of images to process. ``None`` means no
            limit; values <= 0 process nothing.

    Raises:
        ValueError: If the two halves differ in size (odd image width).
    """
    ta_path = target_path + 'A'
    tb_path = target_path + 'B'
    # exist_ok avoids the race between an existence check and the create.
    os.makedirs(ta_path, exist_ok=True)
    os.makedirs(tb_path, exist_ok=True)

    if max_num is not None:
        # Slicing honours the limit exactly. The old `i == max_num - 1` check
        # never fired for max_num <= 0 and therefore processed every file.
        file_paths = file_paths[:max(max_num, 0)]

    for file_path in file_paths:
        # Open via a context manager so the source file is closed promptly.
        with Image.open(file_path) as src:
            img = src.convert('RGB')

        # Split the AB image into its left (A) and right (B) halves.
        w, h = img.size
        w2 = w // 2
        img_a = img.crop((0, 0, w2, h))
        img_b = img.crop((w2, 0, w, h))
        if img_a.size != img_b.size:
            raise ValueError(
                '%s: halves differ in size %s vs %s (odd width?)'
                % (file_path, img_a.size, img_b.size)
            )

        # Take the stem of the file name only. Splitting the full path on '.'
        # cut it at the first dot anywhere (e.g. in a directory name), so
        # distinct inputs were written to the same output file.
        stem = os.path.splitext(os.path.basename(file_path))[0]
        img_a.save(os.path.join(ta_path, '%s.jpg' % stem), format='JPEG', subsampling=0, quality=100)
        img_b.save(os.path.join(tb_path, '%s.jpg' % stem), format='JPEG', subsampling=0, quality=100)

if __name__ == '__main__':
    # Script entry point: split <path>/test and <path>/train AB images into
    # <output>/testA, <output>/testB, <output>/trainA and <output>/trainB.
    import argparse

    parser = argparse.ArgumentParser(
        description='Split side-by-side AB images into separate A and B folders.'
    )
    parser.add_argument(
        '--path',
        # Required: without it os.path.join(None, ...) failed with a
        # confusing TypeError instead of a usage message.
        required=True,
        help='Which folder to process (it should have subfolders test, train)'
    )
    parser.add_argument(
        '--output',
        default='mnist_cd2cb',
        help='Output folder prefix (default: mnist_cd2cb)'
    )
    args = parser.parse_args()

    dataset_folder = args.path
    print(dataset_folder)

    test_file_paths = get_file_paths(os.path.join(dataset_folder, 'test'))
    train_file_paths = get_file_paths(os.path.join(dataset_folder, 'train'))

    unalign_images(test_file_paths, os.path.join(args.output, 'test'), max_num=500)
    unalign_images(train_file_paths, os.path.join(args.output, 'train'), max_num=1000)
